{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd24eaf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from easydict import EasyDict as ED\n",
    "import torch\n",
    "# ^^^ pyforest auto-imports - don't write above this line\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2a97d69",
   "metadata": {},
   "outputs": [],
   "source": [
    "from cfg import TrainCfg, TrainCfg_ns, ModelCfg, ModelCfg_ns\n",
    "from model import ECG_CRNN_CINC2020\n",
    "from dataset import CINC2020\n",
    "from trainer import CINC2020Trainer\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "from copy import deepcopy\n",
    "\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a95233ff",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b23a1ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the class-level `__DEBUG__` attribute to False on the model, trainer and dataset classes\n",
    "ECG_CRNN_CINC2020.__DEBUG__ = CINC2020Trainer.__DEBUG__ = CINC2020.__DEBUG__ = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f8a30bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Point both config objects at the same database directory (machine-specific absolute path)\n",
    "TrainCfg_ns.db_dir = TrainCfg.db_dir = \"/home/wenhao/Jupyter/wenhao/data/CinC2021/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db7d5acc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffb293e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the training and validation datasets from the same config.\n",
    "# NOTE(review): `lazy=False` presumably loads the data eagerly -- confirm in dataset.py\n",
    "ds_train = CINC2020(\n",
    "    TrainCfg_ns,\n",
    "    training=True,\n",
    "    lazy=False,\n",
    ")\n",
    "ds_val = CINC2020(\n",
    "    TrainCfg_ns,\n",
    "    training=False,\n",
    "    lazy=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3954be0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcc1e222",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "d1980bee",
   "metadata": {},
   "source": [
    "## resnet_nature_comm_bottle_neck_se, 1 linear layer (classifier head), AsymmetricLoss, lr=1e-4 to 2e-3, one cycle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aee75514",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- training config: choose the network components by name ---\n",
    "train_config = deepcopy(TrainCfg_ns)\n",
    "train_config.cnn_name = \"resnet_nature_comm_bottle_neck_se\"\n",
    "train_config.rnn_name = \"none\"\n",
    "train_config.attn_name = \"none\"\n",
    "train_config.n_leads = len(train_config.leads)\n",
    "\n",
    "# use the GPU when torch can see one, otherwise stay on the CPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# a non-empty `tranches_for_training` selects its own class list\n",
    "tranches = train_config.tranches_for_training\n",
    "classes = train_config.tranche_classes[tranches] if tranches else train_config.classes\n",
    "\n",
    "# --- model config: mirror the component names chosen above ---\n",
    "model_config = deepcopy(ModelCfg_ns)\n",
    "model_config.cnn.name = train_config.cnn_name\n",
    "model_config.rnn.name = train_config.rnn_name\n",
    "model_config.attn.name = train_config.attn_name\n",
    "\n",
    "# classifier head; `out_channels` does not include the last linear layer,\n",
    "# whose out channels equals n_classes, so an empty list means a single linear layer\n",
    "model_config.clf = ED()\n",
    "model_config.clf.out_channels = []\n",
    "model_config.clf.bias = True\n",
    "model_config.clf.dropouts = 0.0\n",
    "model_config.clf.activation = \"mish\"  # for a single layer `SeqLin`, activation is ignored"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "308ac844",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the network with the class list resolved in the configuration cell:\n",
    "# `classes` is `train_config.tranche_classes[tranches]` when specific tranches are\n",
    "# selected for training and `train_config.classes` otherwise. Passing\n",
    "# `train_config.classes` directly would ignore the tranche selection.\n",
    "model = ECG_CRNN_CINC2020(\n",
    "    classes=classes,\n",
    "    n_leads=train_config.n_leads,\n",
    "    config=model_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca417995",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the model's `module_size_` attribute.\n",
    "# NOTE(review): presumably the parameter count -- confirm in model.py\n",
    "model.module_size_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99025b22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the model in DataParallel when more than one GPU is visible.\n",
    "# The isinstance guard keeps this cell idempotent: re-running it must not\n",
    "# wrap an already wrapped model in a second DataParallel layer.\n",
    "if torch.cuda.device_count() > 1 and not isinstance(model, DP):\n",
    "    model = DP(model)\n",
    "    # DDP(model) is the alternative, but it needs an initialized process group\n",
    "# move the (possibly wrapped) model to the selected device; the returned module is displayed\n",
    "model.to(device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "172f0d27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the trainer for the model / config pair on the selected device.\n",
    "# NOTE(review): `lazy=True` presumably postpones dataloader creation, which would explain\n",
    "# why the datasets are attached by hand through `_setup_dataloaders` in a later cell -- confirm in trainer.py\n",
    "trainer = CINC2020Trainer(\n",
    "    model=model,\n",
    "    model_config=model_config,\n",
    "    train_config=train_config,\n",
    "    device=device,\n",
    "    lazy=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b93ad78a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of classes resolved in the configuration cell (tranche-specific when `tranches_for_training` is set)\n",
    "len(classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bacc5011",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the pre-built training / validation datasets to the trainer.\n",
    "# NOTE(review): `_setup_dataloaders` is a private method of the trainer -- confirm it is meant to be called from outside\n",
    "trainer._setup_dataloaders(ds_train, ds_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e1e91a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start training; whatever `train()` returns is displayed as the cell output\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5281353c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fc2b595",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
